{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert the TensorFlow-Trained BERT Model to a PyTorch Model\n",
    "We do this to show HuggingFace's ability to convert models between TensorFlow and PyTorch.\n",
    "\n",
    "We will subsequently deploy the model as a PyTorch model using the TorchServe Inference Container."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "sm = boto3.Session().client(service_name=\"sagemaker\", region_name=region)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r training_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    training_job_name\n",
    "except NameError:\n",
    "    print(\"+++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please wait for the Training notebook to finish.\")\n",
    "    print(\"+++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Previous training_job_name: {}\".format(training_job_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download the TensorFlow-Trained Model from S3 to this Notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models_dir = \"./model\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp s3://$bucket/$training_job_name/output/model.tar.gz $models_dir/model.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "import pickle as pkl\n",
    "\n",
    "try:\n",
    "    tar = tarfile.open(\"{}/model.tar.gz\".format(models_dir))\n",
    "    tar.extractall(path=models_dir)\n",
    "    tar.close()\n",
    "except Exception as e:\n",
    "    print(\"[ERROR] in tar operation: {}\".format(e))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -al $models_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_model_dir = \"{}/transformers/fine-tuned/\".format(models_dir)\n",
    "\n",
    "!ls -al $transformer_model_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cat $transformer_model_dir/config.json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert the TensorFlow Model to PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DistilBertForSequenceClassification  # PyTorch version\n",
    "\n",
    "try:\n",
    "    loaded_pytorch_model = DistilBertForSequenceClassification.from_pretrained(\n",
    "        transformer_model_dir,\n",
    "        id2label={0: 1, 1: 2, 2: 3, 3: 4, 4: 5},\n",
    "        label2id={1: 0, 2: 1, 3: 2, 4: 3, 5: 4},\n",
    "        from_tf=True,\n",
    "    )\n",
    "except Exception as e:\n",
    "    print(\"[ERROR] in loading model {}: \".format(e))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(type(loaded_pytorch_model))\n",
    "print(loaded_pytorch_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save The Transformer/PyTorch Model with `.save_pretrained()`\n",
    "This will create `pytorch_model.bin`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_models_dir = \"./model/transformers/pytorch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p $pytorch_models_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_pytorch_model.save_pretrained(pytorch_models_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -al $pytorch_models_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load and Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DistilBertTokenizer\n",
    "\n",
    "tokenizer = DistilBertTokenizer.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_model = DistilBertForSequenceClassification.from_pretrained(pytorch_models_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls -al $pytorch_models_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import DistilBertForSequenceClassification\n",
    "from transformers import DistilBertConfig\n",
    "\n",
    "config = DistilBertConfig.from_json_file(\"{}/config.json\".format(pytorch_models_dir))\n",
    "\n",
    "model_path = \"{}/{}\".format(pytorch_models_dir, \"pytorch_model.bin\")\n",
    "model = DistilBertForSequenceClassification.from_pretrained(model_path, config=config)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import smdebug.pytorch as smd\n",
    "from torch import nn\n",
    "\n",
    "max_seq_length = 64\n",
    "classes = [1, 2, 3, 4, 5]\n",
    "\n",
    "model.eval()\n",
    "\n",
    "input_data = '[{\"features\": [\"This is great!\"]}, \\\n",
    "               {\"features\": [\"This is bad.\"]}]'\n",
    "print(\"input_data: {}\".format(input_data))\n",
    "\n",
    "data_json = json.loads(input_data)\n",
    "print(\"data_json: {}\".format(data_json))\n",
    "\n",
    "predicted_classes = []\n",
    "save_config = smd.SaveConfig(save_interval=1)\n",
    "# hook = smd.Hook(\"/tmp/tensors\", save_config=save_config, include_regex='.*')\n",
    "\n",
    "# hook.register_module(model)\n",
    "\n",
    "for data_json_line in data_json:\n",
    "    print(\"data_json_line: {}\".format(data_json_line))\n",
    "    print(\"type(data_json_line): {}\".format(type(data_json_line)))\n",
    "\n",
    "    review_body = data_json_line[\"features\"][0]\n",
    "    print(\"\"\"review_body: {}\"\"\".format(review_body))\n",
    "\n",
    "    encode_plus_token = tokenizer.encode_plus(\n",
    "        review_body,\n",
    "        max_length=max_seq_length,\n",
    "        add_special_tokens=False,\n",
    "        #        return_token_type_ids=False,\n",
    "        return_token_type_ids=None,\n",
    "        #        pad_to_max_length=True,\n",
    "        padding=True,\n",
    "        return_attention_mask=True,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=True,\n",
    "    )\n",
    "\n",
    "    input_ids = encode_plus_token[\"input_ids\"]\n",
    "\n",
    "    #    hook._write_raw_tensor_simple(\"input_tokens\", tokenizer.tokenize(review_body))\n",
    "    attention_mask = encode_plus_token[\"attention_mask\"]\n",
    "\n",
    "    output = model(input_ids, attention_mask)\n",
    "    print(\"output: {}\".format(output))\n",
    "\n",
    "    softmax_fn = nn.Softmax(dim=1)\n",
    "    softmax_output = softmax_fn(output[0])\n",
    "    print(\"softmax_output: {}\".format(softmax_output))\n",
    "\n",
    "    _, prediction = torch.max(softmax_output, dim=1)\n",
    "\n",
    "    predicted_class_idx = prediction.item()\n",
    "    predicted_class = classes[predicted_class_idx]\n",
    "    print(\"predicted_class: {}\".format(predicted_class))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Upload Transformer/PyTorch Model to S3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_pytorch_model_dir_s3_uri = \"s3://{}/model/{}/transformer-pytorch/\".format(bucket, training_job_name)\n",
    "print(transformer_pytorch_model_dir_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mv ./model/transformers/pytorch/pytorch_model.bin ./model/transformers/pytorch/model.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cd ./model/transformers/pytorch/ && tar -cvzf model.tar.gz *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp ./model/transformers/pytorch/model.tar.gz $transformer_pytorch_model_dir_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 ls $transformer_pytorch_model_dir_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store transformer_pytorch_model_dir_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%store"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Release Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%html\n",
    "\n",
    "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n",
    "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n",
    "        \n",
    "<script>\n",
    "try {\n",
    "    els = document.getElementsByClassName(\"sm-command-button\");\n",
    "    els[0].click();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}    \n",
    "</script>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "\n",
    "try {\n",
    "    Jupyter.notebook.save_checkpoint();\n",
    "    Jupyter.notebook.session.delete();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
